package com.sql.mysql.sharding.rewrite;

import java.util.List;

import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.reflection.MetaObject;

import com.sql.mysql.sharding.annotation.ShardingKeyObject;
import com.sql.mysql.sharding.config.ShardGroup;
import com.sql.mysql.sharding.config.ShardingConfigFind;
import com.sql.mysql.sharding.config.Table;
import com.sql.mysql.sharding.hint.ShardingHintUtils;
import com.sql.mysql.sharding.utils.NumberUtils;

import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.merge.Merge;
import net.sf.jsqlparser.statement.replace.Replace;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.update.Update;
import net.sf.jsqlparser.statement.upsert.Upsert;
import net.sf.jsqlparser.util.TablesNamesFinder;

@Slf4j
public class ReWriteEngine {

	// private static void handleSelect(Select statement) {
	// statement.accept(new ShardingStatementVisitor());
	// }
	//
	// private static void handleInsert(Insert statement) {
	// statement.accept(new ShardingStatementVisitor());
	// }
	//
	// private static void handleUpdate(Update statement) {
	// statement.accept(new ShardingStatementVisitor());
	// }
	//
	// private static void handleUpsert(Upsert statement) {
	// statement.accept(new ShardingStatementVisitor());
	// }
	//
	// private static void handleDelete(Delete statement) {
	// statement.accept(new ShardingStatementVisitor());
	// }

	public static void rewrite(MetaObject metaObject, ShardingKeyObject shardingKeyObject, List<ShardGroup> shardGroups,
			BoundSql boundSql) throws Exception {
		try {
			log.debug("enter com.sql.mysql.sharding.rewrite.ReWriteEngine.rewrite");
			// 1)先查找后缀,如果不能转换成long就报错
			long shardingKey = NumberUtils.Object2Long(shardingKeyObject.getValue());
			log.debug("succeed to get sharding key {}", shardingKey);
			// 2)查找table
			Table table = ShardingConfigFind.getTable(shardingKey, shardGroups);
			if (null == table) {
				log.error("no table found for {}", shardingKey);
				throw new Exception("no table found for shardingKey -> " + shardingKey);
			}
			String suffix = table.getSuffix();
			log.debug("succeed to get table, sharding key {} , table's suffix {} ", shardingKey, suffix);
			// 3)获取sql
			String originSql = boundSql.getSql();
			log.debug("before BoundSql {}  ", boundSql);
			log.debug("before  sql {} ", originSql);
			Statement statement = null;
			statement = CCJSqlParserUtil.parse(originSql);//
			// 此处也许可以优化,//TODO
			log.debug("before BoundSql statement {}", statement);
			if (statement instanceof Select //
					|| statement instanceof Insert//
					|| statement instanceof Update//
					|| statement instanceof Upsert//
					|| statement instanceof Delete//
					|| statement instanceof Replace//
					|| statement instanceof Merge//
			) {
				log.debug("valid statement type {}", statement.getClass());
			} else {
				throw new Exception("not valid statement type " + statement);
			}
			// 4)基于TablesNamesFinder来改写table
			log.debug("开始进行sql改写操作");
			ReWriteTablesNamesFinder tablesNamesFinder = new ReWriteTablesNamesFinder(suffix);
			List<String> tableNames = tablesNamesFinder.getTableList(statement);
			log.info("tables {}", tableNames);
			String reWrittenSql = statement.toString();
			{
				// jsqlparser会丢失hint信息
				if (originSql.startsWith(ShardingHintUtils.FORCE_MASTER)
						&& false == reWrittenSql.startsWith(ShardingHintUtils.FORCE_MASTER)) {
					// 补起来
					log.debug("add HintHelper.FORCE_MASTER to reWrittenSql");
					reWrittenSql = ShardingHintUtils.FORCE_MASTER + reWrittenSql;
				}
			}
			log.info("after rewrite: {}", reWrittenSql);
			log.debug("完成sql改写操作");
			// 5)获取rewrite之后的sql,并保存到boundSql里去
			metaObject.setValue("sql", reWrittenSql);
			log.debug("sql 回填 BoundSql 完毕");
		} catch (Exception exception) {
			log.error("error happened {}", exception);
			// 设置为null,后面执行会报错,业务人员就可以看到日志了
			// 一定要设置sql为null,这样执行不下去就报错
			// 因为事务里不是每次都会去取得新的connection
			// 想一想,如果第2次的sql解析失败,但是第1次已经获得了连接
			throw new Exception("error when rewrite sql " + exception);
		}
	}

	public static void main(String[] args) throws JSQLParserException {
		String sql = "INSERT INTO MY_TABLE1 (a) VALUES ((SELECT a from MY_TABLE2 WHERE a = 1))";
		Statement statement = CCJSqlParserUtil.parse(sql);
		TablesNamesFinder tablesNamesFinder = new ReWriteTablesNamesFinder("_0");
		List<String> tableList = tablesNamesFinder.getTableList(statement);
		log.info("sql {}", statement.toString());
		log.info("tableList {}", tableList);
	}
}
